/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "execute_graph_builder.h"

/**
 * Appends a "Data" node to the graph being built.
 *
 * The op definition declares one required input "x" and one required output "y";
 * every string is passed through builder.AddString() (presumably so it lives in the
 * builder's pool).
 *
 * @param builder graph builder that creates the node.
 * @param name    node name.
 * @return the reference returned by the last chained setter (presumably the new node's builder).
 */
ge::ExecuteNodeBuilder &AddData(ge::ExecuteGraphBuilder &builder, const char *name) {
  ge::OpDef data_op_def = {.type = builder.AddString("Data"),
                           .infer_shape_func = nullptr,
                           .tiling_func = nullptr,
                           .inputs_def =
                               {
                                   {ge::kMustIo, builder.AddString("x")},
                               },
                           .outputs_def = {
                               {ge::kMustIo, builder.AddString("y")},
                           }};
  // NOTE(review): &data_op_def is the address of a stack local; this is only safe if
  // SetOpDef copies the definition rather than storing the pointer -- confirm.
  return builder.AddNode()
      .SetOpDef(&data_op_def)
      .SetName(name)
      .SetInputName(0, "x", builder.BatchAddTensorDesc())
      // NOTE(review): the output "y" is registered through SetInputName at input index 0,
      // the same slot as "x". This looks like it should use an output-side setter -- confirm
      // against the ExecuteNodeBuilder API.
      .SetInputName(0, "y", builder.BatchAddTensorDesc());
}

/**
 * Appends a "TransData" node with one required input "src" and one required output "dst".
 *
 * @param builder  graph builder that creates the node.
 * @param name     node name.
 * @param x_tensor tensor desc supplied by the caller; main() passes the upstream Data
 *                 node's output desc here.
 * @return the reference returned by the last chained setter (presumably the new node's builder).
 */
ge::ExecuteNodeBuilder &AddTransData(ge::ExecuteGraphBuilder &builder, const char *name, ge::TensorDescPtr x_tensor) {
  ge::OpDef op_def = {.type = builder.AddString("TransData"),
                      .infer_shape_func = nullptr,
                      .tiling_func = nullptr,
                      .inputs_def =
                          {
                              {ge::kMustIo, builder.AddString("src")},
                          },
                      .outputs_def = {
                          {ge::kMustIo, builder.AddString("dst")},
                      }};
  // NOTE(review): &op_def is the address of a stack local; this is only safe if SetOpDef
  // copies the definition rather than storing the pointer -- confirm.
  return builder.AddNode()
      .SetOpDef(&op_def)
      .SetName(name)
      // NOTE(review): "src" receives a freshly allocated desc, while the caller-supplied
      // x_tensor is bound to "dst" below. Callers pass the producer's output desc, so
      // x_tensor presumably belongs on the "src" input, with a fresh desc for the "dst"
      // output -- confirm.
      .SetInputName(0, "src", builder.BatchAddTensorDesc())
      // NOTE(review): the output "dst" is registered through SetInputName at input index 0,
      // the same slot as "src"; this looks like it should use an output-side setter.
      .SetInputName(0, "dst", x_tensor);
}

/**
 * Appends a "MatMul" node with required inputs "x1" and "x2", an optional input "bias",
 * and one required output "y".
 *
 * @param builder graph builder that creates the node.
 * @param name    node name.
 * @param x1_td   tensor desc attached to input 0 ("x1"); main() passes the upstream
 *                producer's output desc here.
 * @param x2_td   tensor desc attached to input 1 ("x2"); same convention as x1_td.
 * @return the reference returned by the last chained setter (presumably the new node's builder).
 */
ge::ExecuteNodeBuilder &AddMatMul(ge::ExecuteGraphBuilder &builder, const char *name, ge::TensorDescPtr x1_td,
                                  ge::TensorDescPtr x2_td) {
  ge::OpDef op_def = {.type = builder.AddString("MatMul"),
                      .infer_shape_func = nullptr,
                      .tiling_func = nullptr,
                      .inputs_def =
                          {
                              {ge::kMustIo, builder.AddString("x1")},
                              {ge::kMustIo, builder.AddString("x2")},
                              {ge::kOptionalIo, builder.AddString("bias")},
                          },
                      .outputs_def = {
                          {ge::kMustIo, builder.AddString("y")},
                      }};
  // NOTE(review): &op_def is the address of a stack local; this is only safe if SetOpDef
  // copies the definition rather than storing the pointer -- confirm.
  return builder.AddNode()
      .SetOpDef(&op_def)
      .SetName(name)
      // Attach the caller-supplied descs. Previously x1_td/x2_td were ignored and unrelated
      // fresh descs were allocated, unlike AddNetOutput, which attaches the descs it is given.
      .SetInputName(0, "x1", x1_td)
      .SetInputName(1, "x2", x2_td)
      // NOTE(review): the output "y" is registered through SetInputName at input index 0,
      // the same slot as "x1". This looks like it should use an output-side setter -- confirm
      // against the ExecuteNodeBuilder API.
      .SetInputName(0, "y", builder.BatchAddTensorDesc());
}

/**
 * Appends the graph's "NetOutput" node.
 *
 * The op definition declares a single dynamic input "x" and no outputs. The dynamic
 * input is expanded into one input per element of input_tds, named "x0", "x1", ...
 * at indices 0, 1, ... in the order given.
 *
 * @param builder   graph builder that creates the node.
 * @param input_tds tensor descs to attach, one per NetOutput input, in order.
 * @return the builder obtained from AddNode().SetOpDef(...).SetName(...).
 */
ge::ExecuteNodeBuilder &AddNetOutput(ge::ExecuteGraphBuilder &builder,
                                     std::initializer_list<ge::TensorDescPtr> input_tds) {
  ge::OpDef op_def = {.type = builder.AddString("NetOutput"),
                      .infer_shape_func = nullptr,
                      .tiling_func = nullptr,
                      .inputs_def =
                          {
                              {ge::kDynamicIo, builder.AddString("x")},
                          },
                      .outputs_def = {}};
  // NOTE(review): &op_def is the address of a stack local; this is only safe if SetOpDef
  // copies the definition rather than storing the pointer -- confirm.
  auto &node_builder = builder.AddNode().SetOpDef(&op_def).SetName("NetOutput");
  size_t i = 0;
  for (auto td : input_tds) {
    std::string input_name = "x" + std::to_string(i);
    // NOTE(review): input_name is destroyed at the end of this iteration, so SetInputName
    // must copy the name rather than keep the c_str() pointer -- confirm.
    node_builder.SetInputName(static_cast<int>(i), input_name.c_str(), td);
    // Advance to the next dynamic-input slot. Without this, every tensor was registered
    // at index 0 under the name "x0".
    ++i;
  }

  return node_builder;
}

/**
 * Smoke test for ExecuteGraphBuilder. Builds the graph
 *   data1 -> trans1 -> matmul1 (input 0)
 *   data2 -> trans2 -> matmul1 (input 1)
 *   matmul1 -> NetOutput (input 0)
 * and calls Build(). The resulting graph is not inspected yet; see the TODOs at the end.
 */
int main() {
  ge::ExecuteGraphBuilder builder;
  auto &data1 = AddData(builder, "data1");
  auto &data2 = AddData(builder, "data2");
  // Each TransData node is handed the output desc of the Data node that feeds it.
  auto &transdata1 = AddTransData(builder, "trans1", data1.GetOutputTensorDesc(0));
  auto &transdata2 = AddTransData(builder, "trans2", data2.GetOutputTensorDesc(0));
  auto &matmul = AddMatMul(builder, "matmul1", transdata1.GetOutputTensorDesc(0), transdata2.GetOutputTensorDesc(0));
  auto &netoutput = AddNetOutput(builder, {matmul.GetOutputTensorDesc(0)});
  // Wire the edges as (src node, src output index, dst node, dst input index), then build.
  auto graph = builder.AddEdge(data1, 0, transdata1, 0)
                   .AddEdge(data2, 0, transdata2, 0)
                   .AddEdge(transdata1, 0, matmul, 0)
                   .AddEdge(transdata2, 0, matmul, 1)
                   .AddEdge(matmul, 0, netoutput, 0)
                   .Build();

  // TODO: verify that the addresses of the strings, TensorDescs and OpDefs all lie inside
  //       the pool, to guard against accidentally using memory outside the pool.
  // TODO: verify that duplicate node names are detected.
  // TODO: verify whether the TensorDescs are the same instance (presumably across each
  //       connected producer/consumer pair).
}